#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Worker thread for running LLaDA model generation.
"""

import os
import sys
import gc
import torch
import torch.nn.functional as F
import numpy as np
from PyQt6.QtCore import QThread, pyqtSignal
import logging

from config import CRITICAL_GPU_MEMORY_THRESHOLD
from utils import cleanup_gpu_memory, get_model_path, format_error

# Configure root logging at import time and create this module's logger.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

class LLaDAWorker(QThread):
    """Worker thread for handling LLaDA generation.

    ``run()`` loads the tokenizer and model on the CPU, formats ``prompt``
    with the tokenizer's chat template, calls ``generate`` once and reports
    the decoded answer through the signals below.

    Signals:
        progress(int, str, dict): percentage, status message and an extra
            payload dict (empty for most updates).
        step_update(int, list, list, list): step, tokens, masks,
            confidences. Declared here but not emitted by this class.
        finished(str): the decoded answer. NOTE: this shadows the
            argument-less ``QThread.finished`` signal.
        error(str): failure or cancellation message.
        memory_warning(str): declared here but not emitted by this class.
    """
    progress = pyqtSignal(int, str, dict)
    step_update = pyqtSignal(int, list, list, list)  # step, tokens, masks, confidences
    finished = pyqtSignal(str)
    error = pyqtSignal(str)
    memory_warning = pyqtSignal(str)

    # Upper bounds applied to the configured values because generation
    # runs on the CPU.
    _MAX_CPU_GEN_LENGTH = 32
    _MAX_CPU_STEPS = 32
    _MAX_CPU_BLOCK_LENGTH = 16

    def __init__(self, prompt, config, parent=None):
        """Store the generation request.

        Args:
            prompt: User text; sent as the content of a single "user"
                chat message.
            config: Dict read by ``run()`` for the keys 'gen_length',
                'steps', 'block_length', 'temperature', 'cfg_scale' and
                'remasking'. ``run()`` also overwrites ``config['device']``
                with 'cpu'.
            parent: Optional Qt parent object.
        """
        super().__init__(parent)
        self.prompt = prompt
        self.config = config
        self.stopped = False

    def stop(self):
        """Stop the generation process.

        Only sets a flag that ``run()`` polls between its stages; a
        ``generate`` call that is already running is not interrupted.
        """
        self.stopped = True

    def _emit_if_cancelled(self):
        """Emit the cancellation error and return True if stop() was called."""
        if self.stopped:
            self.error.emit("Generation cancelled.")
            return True
        return False

    def _limit_for_cpu(self, gen_length, steps, block_length):
        """Clamp generation parameters to the CPU limits.

        Returns:
            Tuple ``(gen_length, steps, block_length)`` where every value is
            at least 1 and ``gen_length`` is a multiple of ``block_length``.

        Raises:
            ValueError: If any configured value is smaller than 1.
        """
        if gen_length < 1 or steps < 1 or block_length < 1:
            raise ValueError(
                "gen_length, steps and block_length must be positive "
                f"(got {gen_length}, {steps}, {block_length})"
            )

        gen_length = min(gen_length, self._MAX_CPU_GEN_LENGTH)
        steps = min(steps, self._MAX_CPU_STEPS)
        # A block may not be longer than the whole generation; without this
        # clamp the rounding below would collapse gen_length to 0.
        block_length = min(block_length, self._MAX_CPU_BLOCK_LENGTH, gen_length)

        # Make sure gen_length is divisible by block_length
        gen_length = (gen_length // block_length) * block_length
        return gen_length, steps, block_length

    def run(self):
        """Load the model, generate a reply to ``self.prompt`` and emit it.

        Emits ``progress`` along the way, then exactly one of ``finished``
        (with the decoded answer) or ``error`` (failure or cancellation).
        Exceptions are reported through ``error`` rather than propagated.
        """
        model = None
        try:
            # Only import these here to avoid loading the model at startup
            from transformers import AutoTokenizer, AutoModel

            # Override device to use CPU
            device = 'cpu'
            self.config['device'] = device

            self.progress.emit(5, "Using CPU for generation...", {})

            # Clear CUDA cache if CUDA is available (won't hurt)
            if torch.cuda.is_available():
                cleanup_gpu_memory()

            model_path = get_model_path()

            try:
                self.progress.emit(10, "Loading tokenizer...", {})
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path,
                    trust_remote_code=True
                )

                self.progress.emit(15, "Loading model on CPU (this may take a while)...", {})

                # Load model directly to CPU with float32 precision
                model = AutoModel.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.float32
                ).to(device).eval()

                self.progress.emit(25, "Model loaded successfully", {})

            except Exception as e:
                # Log with traceback, then let the outer handler report it
                logger.exception("Error loading model: %s", e)
                raise

            # Loading is slow; honour a stop() that arrived meanwhile instead
            # of paying for a generation whose result would be discarded.
            if self._emit_if_cancelled():
                return

            self.progress.emit(30, "Tokenizing input...", {})

            # Prepare input according to chat.py
            messages = [{"role": "user", "content": self.prompt}]
            user_input = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=False
            )
            input_ids = tokenizer(user_input)['input_ids']

            # Put input tensor on CPU and add a leading batch dimension
            input_ids = torch.tensor(input_ids).cpu().unsqueeze(0)

            # Setup parameters from config
            original_gen_length = self.config['gen_length']
            original_steps = self.config['steps']
            original_block_length = self.config['block_length']
            temperature = self.config['temperature']
            cfg_scale = self.config['cfg_scale']
            remasking = self.config['remasking']

            # Limit parameters for CPU generation
            gen_length, steps, block_length = self._limit_for_cpu(
                original_gen_length, original_steps, original_block_length
            )

            # Report changes if any
            if (original_gen_length != gen_length or
                    original_steps != steps or
                    original_block_length != block_length):
                self.progress.emit(
                    32,
                    f"Adjusted parameters for CPU mode: "
                    f"length={original_gen_length}->{gen_length}, "
                    f"steps={original_steps}->{steps}, "
                    f"block={original_block_length}->{block_length}",
                    {}
                )

            self.progress.emit(40, f"Starting generation with {steps} steps on CPU...", {
                'prompt_length': input_ids.shape[1],
                'params': {
                    'gen_length': gen_length,
                    'steps': steps,
                    'block_length': block_length,
                    'device': device,
                    'temperature': temperature,
                    'cfg_scale': cfg_scale,
                    'remasking': remasking
                }
            })

            if self._emit_if_cancelled():
                return

            # Generate text
            try:
                # Import the generate function
                from generate import generate

                self.progress.emit(45, "Running text generation on CPU (will be slow)...", {})

                out = generate(
                    model,
                    input_ids,
                    steps=steps,
                    gen_length=gen_length,
                    block_length=block_length,
                    temperature=temperature,
                    cfg_scale=cfg_scale,
                    remasking=remasking
                )

                # generate() cannot be interrupted, so a stop() issued while
                # it ran is only honoured here and the result is dropped.
                if self._emit_if_cancelled():
                    return

                # Decode only the tokens after the prompt
                self.progress.emit(95, "Decoding output...", {})
                answer = tokenizer.batch_decode(
                    out[:, input_ids.shape[1]:], skip_special_tokens=True
                )[0]

                self.progress.emit(100, "Generation complete", {'output': answer})
                self.finished.emit(answer)
            except Exception as gen_error:
                logger.exception("Error during generation: %s", gen_error)
                self.error.emit(f"Error during generation: {str(gen_error)}")

        except Exception as e:
            logger.exception("Unhandled exception: %s", e)
            self.error.emit(format_error(e))

        finally:
            # Cleanup on every exit path (success, cancellation, failure):
            # drop the model reference and run a collection pass.
            model = None
            gc.collect()
            if torch.cuda.is_available():
                cleanup_gpu_memory()
